import numpy as np
import pandas as pd
from environment import Plane

class RL(object):
    """Base class for a tabular Sarsa-style learner over flight-ordering states.

    ``self.sarsa_table`` keeps one row per visited state: column ``0`` holds
    the state itself (a list of action labels), and the remaining integer
    columns hold the learned delay value of each action in that state.
    """

    def __init__(self , actions_space , learning_rate = 0.9, reward_decay = 0.9 , e_greedy = 0.9 , step = 2):
        # The available action labels double as the DataFrame column labels.
        self.actions = actions_space
        self.lr = learning_rate        # learning rate (alpha) for the update rule
        self.gamma = reward_decay      # discount / reward-decay factor
        self.epsilon = e_greedy        # epsilon-greedy threshold (set but not read in this class)
        self.sarsa_table = pd.DataFrame(columns = self.actions , dtype = np.float64)
        self.step = step

    # 2.2 Pick the best action for a state, derive the next state, and report the delay.
    def choose_action(self , observation):
        """Choose the minimum-delay action for *observation*.

        Returns ``(dtime, next_observation)`` where ``dtime`` is the smallest
        recorded delay in this state and ``next_observation`` is the state the
        chosen action leads to (two adjacent items swapped, or unchanged for
        the "stay put" action).
        """
        pd.set_option('display.width', None)
        # 2.2.1 Locate this state's row; columns 1.. hold the per-action delays.
        actions = self.sarsa_table[0].tolist()
        observation_actions = self.sarsa_table.loc[actions.index(observation), 1:]
        # Break ties between equally-minimal actions uniformly at random.
        best_action = np.random.choice(observation_actions[observation_actions == np.min(observation_actions)].index)
        dtime = np.min(observation_actions)

        # 2.2.2 Action label 37 means "stay put": the state is returned unchanged.
        # NOTE(review): 37 is a magic number — presumably tied to the action
        # encoding used by environment.Plane; confirm against that module.
        if best_action == 37:
            observation = observation.copy()
        else:
            # Swap the chosen element with its left neighbour to form the next state.
            t_1 = observation.index(best_action) - 1
            t = observation.index(best_action)
            observation[t_1] , observation[t] = observation[t] , observation[t_1]
        return dtime , observation


    # Check the Sarsa table and extend it when a new state appears.
    def check_and_update(self , state):
        """Append a zero-initialised row for *state* if it is not in the table yet."""
        if state not in self.sarsa_table[0].tolist():
            line = self.add_line(len(state) + 1 , state)
            # BUGFIX: DataFrame.append() was removed in pandas 2.0;
            # pd.concat() is the supported replacement and behaves the same here.
            self.sarsa_table = pd.concat([self.sarsa_table, pd.DataFrame([line])], ignore_index = True)

    def add_line(self , n , state):
        """Build a row dict: column 0 is the state, columns 1..n-1 start at 0.0."""
        line = {0: state}
        for i in range(1 , n):
            line[i] = 0.0
        return line


class SarsaTable(RL):
    """Concrete Sarsa learner: records observed delays and applies the update rule."""

    def __init__(self , actions , learning_rate = 0.9 , reward_decay = 0.9 , e_greedy = 0.9 , step = 2):
        # All state lives in the RL base class; just forward the hyper-parameters.
        super().__init__(actions , learning_rate , reward_decay , e_greedy , step)

    # 2.1.2 Record an observed delay into the table.
    def writing_in_sarsa_table(self , action, dtime, ob):
        """Write the delay *dtime* for *action* taken in state *ob* (adding the row if new)."""
        self.check_and_update(ob)
        row = self.sarsa_table[0].tolist().index(ob)
        self.sarsa_table.loc[row, action] = dtime

    # 2.3 Apply one Sarsa-style value update for (observation, action).
    def learn(self , observation, action, dtime, next_observation):
        """Move the stored value toward ``dtime + gamma * min(next state's values)``."""
        self.check_and_update(next_observation)
        states = self.sarsa_table[0].tolist()
        current_row = states.index(observation)
        next_row = states.index(next_observation)
        predicted = self.sarsa_table.loc[current_row , action]
        # Target uses the *minimum* over the next state's action columns (delay minimisation).
        target = dtime + self.gamma * self.sarsa_table.loc[next_row , 1:].min()
        self.sarsa_table.loc[current_row , action] += self.lr * (target - predicted)

